{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/lvyufeng/miniconda3/envs/mindspore/lib/python3.11/site-packages/numpy/core/getlimits.py:549: UserWarning: The value of the smallest subnormal for <class 'numpy.float64'> type is zero.\n",
      "  setattr(self, word, getattr(machar, word).flat[0])\n",
      "/home/lvyufeng/miniconda3/envs/mindspore/lib/python3.11/site-packages/numpy/core/getlimits.py:89: UserWarning: The value of the smallest subnormal for <class 'numpy.float64'> type is zero.\n",
      "  return self._float_to_str(self.smallest_subnormal)\n",
      "/home/lvyufeng/miniconda3/envs/mindspore/lib/python3.11/site-packages/numpy/core/getlimits.py:549: UserWarning: The value of the smallest subnormal for <class 'numpy.float32'> type is zero.\n",
      "  setattr(self, word, getattr(machar, word).flat[0])\n",
      "/home/lvyufeng/miniconda3/envs/mindspore/lib/python3.11/site-packages/numpy/core/getlimits.py:89: UserWarning: The value of the smallest subnormal for <class 'numpy.float32'> type is zero.\n",
      "  return self._float_to_str(self.smallest_subnormal)\n",
      "/home/lvyufeng/miniconda3/envs/mindspore/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n",
      "Building prefix dict from the default dictionary ...\n",
      "Loading model from cache /tmp/jieba.cache\n",
      "Loading model cost 1.345 seconds.\n",
      "Prefix dict has been built successfully.\n",
      "/home/lvyufeng/miniconda3/envs/mindspore/lib/python3.11/site-packages/Cython/Compiler/Main.py:381: FutureWarning: Cython directive 'language_level' not set, using '3str' for now (Py3). This has changed from earlier releases! File: /home/lvyufeng/miniconda3/envs/mindspore/lib/python3.11/site-packages/mindnlp/transformers/models/graphormer/algos_graphormer.pyx\n",
      "  tree = Parsing.p_module(s, pxd, full_module_name)\n",
      "In file included from /home/lvyufeng/miniconda3/envs/mindspore/lib/python3.11/site-packages/numpy/core/include/numpy/ndarraytypes.h:1929,\n",
      "                 from /home/lvyufeng/miniconda3/envs/mindspore/lib/python3.11/site-packages/numpy/core/include/numpy/ndarrayobject.h:12,\n",
      "                 from /home/lvyufeng/miniconda3/envs/mindspore/lib/python3.11/site-packages/numpy/core/include/numpy/arrayobject.h:5,\n",
      "                 from /home/lvyufeng/.pyxbld/temp.linux-aarch64-cpython-311/home/lvyufeng/miniconda3/envs/mindspore/lib/python3.11/site-packages/mindnlp/transformers/models/graphormer/algos_graphormer.c:1240:\n",
      "/home/lvyufeng/miniconda3/envs/mindspore/lib/python3.11/site-packages/numpy/core/include/numpy/npy_1_7_deprecated_api.h:17:2: warning: #warning \"Using deprecated NumPy API, disable it with \" \"#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION\" [-Wcpp]\n",
      "   17 | #warning \"Using deprecated NumPy API, disable it with \" \\\n",
      "      |  ^~~~~~~\n"
     ]
    }
   ],
   "source": [
    "import mindspore\n",
    "from mindspore.dataset import GeneratorDataset, transforms\n",
    "\n",
    "from mindnlp.engine import Trainer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# prepare dataset\n",
    "class SentimentDataset:\n",
    "    \"\"\"Random-access dataset over a tab-separated ``label<TAB>text`` file.\n",
    "\n",
    "    The first line of the file is treated as a header and skipped. Every\n",
    "    remaining non-blank line must hold an integer label, a tab, and the text.\n",
    "    Indexing yields ``(label, text)`` tuples.\n",
    "\n",
    "    Args:\n",
    "        path: Path to the UTF-8 encoded TSV file.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If a data line has no tab separator or a non-integer label.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, path):\n",
    "        self.path = path\n",
    "        self._labels, self._text_a = [], []\n",
    "        self._load()\n",
    "\n",
    "    def _load(self):\n",
    "        \"\"\"Parse ``self.path`` into the parallel label / text lists.\"\"\"\n",
    "        with open(self.path, \"r\", encoding=\"utf-8\") as f:\n",
    "            next(f, None)  # skip the header row (no-op on an empty file)\n",
    "            for line in f:\n",
    "                line = line.rstrip(\"\\n\")\n",
    "                if not line:\n",
    "                    # tolerate blank lines instead of assuming exactly one\n",
    "                    # trailing newline at the end of the file\n",
    "                    continue\n",
    "                # split on the first tab only so tabs inside the text survive\n",
    "                label, text_a = line.split(\"\\t\", 1)\n",
    "                self._labels.append(int(label))\n",
    "                self._text_a.append(text_a)\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        \"\"\"Return the ``(label, text)`` pair stored at ``index``.\"\"\"\n",
    "        return self._labels[index], self._text_a[index]\n",
    "\n",
    "    def __len__(self):\n",
    "        \"\"\"Return the number of loaded samples.\"\"\"\n",
    "        return len(self._labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--2024-11-22 16:04:29--  https://baidu-nlp.bj.bcebos.com/emotion_detection-dataset-1.0.0.tar.gz\n",
      "Resolving baidu-nlp.bj.bcebos.com (baidu-nlp.bj.bcebos.com)... 198.18.0.62\n",
      "Connecting to baidu-nlp.bj.bcebos.com (baidu-nlp.bj.bcebos.com)|198.18.0.62|:443... connected.\n",
      "HTTP request sent, awaiting response... 200 OK\n",
      "Length: 1710581 (1.6M) [application/x-gzip]\n",
      "Saving to: ‘emotion_detection.tar.gz’\n",
      "\n",
      "emotion_detection.t 100%[===================>]   1.63M  6.97MB/s    in 0.2s    \n",
      "\n",
      "2024-11-22 16:04:30 (6.97 MB/s) - ‘emotion_detection.tar.gz’ saved [1710581/1710581]\n",
      "\n",
      "data/\n",
      "data/test.tsv\n",
      "data/infer.tsv\n",
      "data/dev.tsv\n",
      "data/train.tsv\n",
      "data/vocab.txt\n"
     ]
    }
   ],
   "source": [
    "# Download the emotion-detection archive and unpack it into the working directory.\n",
    "# The recorded run extracted data/{train,dev,test,infer}.tsv and data/vocab.txt.\n",
    "!wget https://baidu-nlp.bj.bcebos.com/emotion_detection-dataset-1.0.0.tar.gz -O emotion_detection.tar.gz\n",
    "!tar xvf emotion_detection.tar.gz"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def process_dataset(source, tokenizer, max_seq_len=64, batch_size=32, shuffle=True):\n",
    "    \"\"\"Build a tokenized, batched MindSpore pipeline from a (label, text) source.\n",
    "\n",
    "    On the Ascend device target every sample is padded/truncated to exactly\n",
    "    ``max_seq_len`` tokens so that all batches share one static shape. On any\n",
    "    other device target samples are truncated to at most ``max_seq_len`` tokens\n",
    "    and padding is left to ``padded_batch``.\n",
    "\n",
    "    Args:\n",
    "        source: Object handed to ``GeneratorDataset``; expected to yield\n",
    "            ``(label, text)`` pairs matching the ``label`` / ``text_a`` columns.\n",
    "        tokenizer: Callable whose result provides ``input_ids`` and\n",
    "            ``attention_mask``; must also expose ``pad_token_id``.\n",
    "        max_seq_len: Maximum number of tokens kept per sample.\n",
    "        batch_size: Number of samples per batch.\n",
    "        shuffle: Forwarded to ``GeneratorDataset``.\n",
    "\n",
    "    Returns:\n",
    "        The batched dataset with an ``input_ids``, ``attention_mask`` and\n",
    "        ``labels`` column.\n",
    "    \"\"\"\n",
    "    is_ascend = mindspore.get_context('device_target') == 'Ascend'\n",
    "\n",
    "    dataset = GeneratorDataset(source, column_names=[\"label\", \"text_a\"], shuffle=shuffle)\n",
    "\n",
    "    def tokenize_and_pad(text):\n",
    "        if is_ascend:\n",
    "            # static shape: pad every sample to exactly max_seq_len\n",
    "            tokenized = tokenizer(text, padding='max_length', truncation=True, max_length=max_seq_len)\n",
    "        else:\n",
    "            # dynamic shape: still cap the length so max_seq_len is honoured on\n",
    "            # every backend; padding is deferred to padded_batch below\n",
    "            tokenized = tokenizer(text, truncation=True, max_length=max_seq_len)\n",
    "        return tokenized['input_ids'], tokenized['attention_mask']\n",
    "\n",
    "    # text -> (input_ids, attention_mask); label -> int32 column named 'labels'\n",
    "    dataset = dataset.map(operations=tokenize_and_pad, input_columns=\"text_a\", output_columns=['input_ids', 'attention_mask'])\n",
    "    dataset = dataset.map(operations=[transforms.TypeCast(mindspore.int32)], input_columns=\"label\", output_columns='labels')\n",
    "\n",
    "    if is_ascend:\n",
    "        # every sample already has the same length, so plain batching suffices\n",
    "        return dataset.batch(batch_size)\n",
    "    # variable-length samples: pad each column with its own fill value while batching\n",
    "    pad_info = {'input_ids': (None, tokenizer.pad_token_id), 'attention_mask': (None, 0)}\n",
    "    return dataset.padded_batch(batch_size, pad_info=pad_info)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "昇腾NPU环境下暂不支持动态Shape，因此上面的 `process_dataset` 在 Ascend 上会把每条样本填充/截断到 `max_seq_len`（静态Shape）；在其他设备上则由 `padded_batch` 按批次动态填充。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/lvyufeng/miniconda3/envs/mindspore/lib/python3.11/site-packages/mindnlp/transformers/tokenization_utils_base.py:1526: FutureWarning: `clean_up_tokenization_spaces` was not set. It will be set to `True` by default. This behavior will be depracted, and will be then set to `False` by default. \n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "from mindnlp.transformers import BertTokenizer\n",
    "tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokenizer.pad_token_id"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Build the batched pipeline for each split. Train and dev use the default\n",
    "# shuffle=True; the test split is read in file order (shuffle=False).\n",
    "dataset_train = process_dataset(SentimentDataset(\"data/train.tsv\"), tokenizer)\n",
    "dataset_val = process_dataset(SentimentDataset(\"data/dev.tsv\"), tokenizer)\n",
    "dataset_test = process_dataset(SentimentDataset(\"data/test.tsv\"), tokenizer, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['input_ids', 'attention_mask', 'labels']"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dataset_train.get_col_names()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "mindspore.dataset.engine.datasets.BatchDataset"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "type(dataset_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'input_ids': Tensor(shape=[32, 64], dtype=Int64, value=\n",
      "[[ 101, 6612, 2209 ...    0,    0,    0],\n",
      " [ 101, 2769, 7270 ...    0,    0,    0],\n",
      " [ 101, 5439, 2038 ...    0,    0,    0],\n",
      " ...\n",
      " [ 101, 6929, 2769 ...    0,    0,    0],\n",
      " [ 101, 2769, 6230 ...    0,    0,    0],\n",
      " [ 101, 2218, 5464 ...    0,    0,    0]]), 'attention_mask': Tensor(shape=[32, 64], dtype=Int64, value=\n",
      "[[1, 1, 1 ... 0, 0, 0],\n",
      " [1, 1, 1 ... 0, 0, 0],\n",
      " [1, 1, 1 ... 0, 0, 0],\n",
      " ...\n",
      " [1, 1, 1 ... 0, 0, 0],\n",
      " [1, 1, 1 ... 0, 0, 0],\n",
      " [1, 1, 1 ... 0, 0, 0]]), 'labels': Tensor(shape=[32], dtype=Int32, value= [1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, \n",
      " 1, 1, 1, 1, 1, 1, 2, 1])}\n"
     ]
    }
   ],
   "source": [
    "print(next(dataset_train.create_dict_iterator()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[MS_ALLOC_CONF]Runtime config:  enable_vmm:True  vmm_align_size:2MB\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-chinese and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    }
   ],
   "source": [
    "from mindnlp.transformers import BertForSequenceClassification, BertModel\n",
    "\n",
    "# Load the pretrained bert-base-chinese weights with a 3-label sequence-classification head.\n",
    "# The classifier weight/bias are not in the checkpoint and start newly initialized\n",
    "# (see the warning printed by this cell), so they are learned during fine-tuning.\n",
    "model = BertForSequenceClassification.from_pretrained('bert-base-chinese', num_labels=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from mindnlp.engine import TrainingArguments\n",
    "\n",
    "# Train for 3 epochs with evaluation, checkpointing and logging all set to\n",
    "# once per epoch; output_dir is relative to the working directory.\n",
    "# NOTE(review): load_best_model_at_end presumably restores the best checkpoint\n",
    "# after training - confirm against the mindnlp TrainingArguments docs.\n",
    "training_args = TrainingArguments(\n",
    "    output_dir=\"bert_emotect_finetune\",\n",
    "    evaluation_strategy=\"epoch\",\n",
    "    save_strategy=\"epoch\",\n",
    "    logging_strategy=\"epoch\",\n",
    "    load_best_model_at_end=True,\n",
    "    num_train_epochs=3.0\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "from mindnlp import evaluate\n",
    "import numpy as np\n",
    "\n",
    "metric = evaluate.load(\"accuracy\")\n",
    "\n",
    "def compute_metrics(eval_pred):\n",
    "    \"\"\"Score one evaluation pass with the notebook-level ``metric``.\n",
    "\n",
    "    Args:\n",
    "        eval_pred: Two-item iterable ``(logits, labels)``.\n",
    "\n",
    "    Returns:\n",
    "        Whatever ``metric.compute`` returns for the arg-max class predictions.\n",
    "    \"\"\"\n",
    "    scores, references = eval_pred\n",
    "    # the highest-scoring entry along the last axis is the predicted class\n",
    "    predicted_classes = np.argmax(scores, axis=-1)\n",
    "    return metric.compute(predictions=predicted_classes, references=references)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    args=training_args,\n",
    "    train_dataset=dataset_train,\n",
    "    eval_dataset=dataset_val,\n",
    "    compute_metrics=compute_metrics\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|                                                                   | 1/906 [00:12<3:07:56, 12.46s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "-\\|/-\\|/"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 33%|██████████████████████▎                                            | 302/906 [03:52<04:28,  2.25it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'loss': 0.3296, 'learning_rate': 3.3333333333333335e-05, 'epoch': 1.0}\n",
      "-"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "  0%|                                                                              | 0/34 [00:00<?, ?it/s]\u001b[A\n",
      "  6%|████                                                                  | 2/34 [00:00<00:08,  3.71it/s]\u001b[A\n",
      "  9%|██████▏                                                               | 3/34 [00:00<00:10,  2.94it/s]\u001b[A\n",
      " 18%|████████████▎                                                         | 6/34 [00:01<00:04,  6.69it/s]\u001b[A\n",
      " 26%|██████████████████▌                                                   | 9/34 [00:01<00:02, 10.18it/s]\u001b[A\n",
      " 35%|████████████████████████▎                                            | 12/34 [00:01<00:01, 13.22it/s]\u001b[A\n",
      " 44%|██████████████████████████████▍                                      | 15/34 [00:01<00:01, 15.79it/s]\u001b[A\n",
      " 53%|████████████████████████████████████▌                                | 18/34 [00:01<00:00, 17.86it/s]\u001b[A\n",
      " 62%|██████████████████████████████████████████▌                          | 21/34 [00:01<00:00, 19.34it/s]\u001b[A\n",
      " 71%|████████████████████████████████████████████████▋                    | 24/34 [00:01<00:00, 20.46it/s]\u001b[A\n",
      " 79%|██████████████████████████████████████████████████████▊              | 27/34 [00:01<00:00, 21.35it/s]\u001b[A\n",
      " 88%|████████████████████████████████████████████████████████████▉        | 30/34 [00:02<00:00, 21.85it/s]\u001b[A\n",
      "                                                                                                          \u001b[A\n",
      " 33%|██████████████████████▎                                            | 302/906 [03:55<04:28,  2.25it/s]\n",
      "100%|█████████████████████████████████████████████████████████████████████| 34/34 [00:02<00:00, 22.53it/s]\u001b[A\n",
      "                                                                                                          \u001b[A"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'eval_loss': 0.17432598769664764, 'eval_accuracy': 0.9407407407407408, 'eval_runtime': 2.9683, 'eval_samples_per_second': 11.454, 'eval_steps_per_second': 1.684, 'epoch': 1.0}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 67%|███████████████████████████████████████████████████▎                         | 604/906 [06:20<02:09,  2.34it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'loss': 0.1979, 'learning_rate': 1.6666666666666667e-05, 'epoch': 2.0}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "  0%|                                                                                        | 0/34 [00:00<?, ?it/s]\u001b[A\n",
      "  6%|████▋                                                                           | 2/34 [00:00<00:04,  7.61it/s]\u001b[A\n",
      "  9%|███████                                                                         | 3/34 [00:00<00:06,  5.04it/s]\u001b[A\n",
      " 12%|█████████▍                                                                      | 4/34 [00:00<00:07,  4.20it/s]\u001b[A\n",
      " 15%|███████████▊                                                                    | 5/34 [00:01<00:06,  4.22it/s]\u001b[A\n",
      " 24%|██████████████████▊                                                             | 8/34 [00:01<00:03,  8.11it/s]\u001b[A\n",
      " 32%|█████████████████████████▌                                                     | 11/34 [00:01<00:02, 11.38it/s]\u001b[A\n",
      " 41%|████████████████████████████████▌                                              | 14/34 [00:01<00:01, 13.97it/s]\u001b[A\n",
      " 50%|███████████████████████████████████████▌                                       | 17/34 [00:01<00:01, 15.90it/s]\u001b[A\n",
      " 59%|██████████████████████████████████████████████▍                                | 20/34 [00:01<00:00, 17.37it/s]\u001b[A\n",
      " 68%|█████████████████████████████████████████████████████▍                         | 23/34 [00:01<00:00, 18.49it/s]\u001b[A\n",
      " 76%|████████████████████████████████████████████████████████████▍                  | 26/34 [00:02<00:00, 19.22it/s]\u001b[A\n",
      " 85%|███████████████████████████████████████████████████████████████████▍           | 29/34 [00:02<00:00, 19.74it/s]\u001b[A\n",
      "                                                                                                                    \u001b[A\n",
      " 67%|███████████████████████████████████████████████████▎                         | 604/906 [06:23<02:09,  2.34it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████| 34/34 [00:02<00:00, 20.11it/s]\u001b[A\n",
      "                                                                                                                    \u001b[A"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'eval_loss': 0.11939491331577301, 'eval_accuracy': 0.9546296296296296, 'eval_runtime': 2.9077, 'eval_samples_per_second': 11.693, 'eval_steps_per_second': 1.72, 'epoch': 2.0}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████| 906/906 [08:55<00:00,  2.34it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'loss': 0.1204, 'learning_rate': 0.0, 'epoch': 3.0}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "  0%|                                                                                        | 0/34 [00:00<?, ?it/s]\u001b[A\n",
      "  6%|████▋                                                                           | 2/34 [00:00<00:04,  7.13it/s]\u001b[A\n",
      " 15%|███████████▊                                                                    | 5/34 [00:00<00:02, 13.80it/s]\u001b[A\n",
      " 26%|█████████████████████▏                                                          | 9/34 [00:00<00:01, 19.99it/s]\u001b[A\n",
      " 35%|███████████████████████████▉                                                   | 12/34 [00:00<00:00, 22.84it/s]\u001b[A\n",
      " 47%|█████████████████████████████████████▏                                         | 16/34 [00:00<00:00, 25.60it/s]\u001b[A\n",
      " 56%|████████████████████████████████████████████▏                                  | 19/34 [00:00<00:00, 26.60it/s]\u001b[A\n",
      " 65%|███████████████████████████████████████████████████                            | 22/34 [00:00<00:00, 27.16it/s]\u001b[A\n",
      " 74%|██████████████████████████████████████████████████████████                     | 25/34 [00:01<00:00, 27.59it/s]\u001b[A\n",
      " 82%|█████████████████████████████████████████████████████████████████              | 28/34 [00:01<00:00, 27.73it/s]\u001b[A\n",
      " 91%|████████████████████████████████████████████████████████████████████████       | 31/34 [00:01<00:00, 27.72it/s]\u001b[A\n",
      "                                                                                                                    \u001b[A\n",
      "100%|█████████████████████████████████████████████████████████████████████████████| 906/906 [08:57<00:00,  2.34it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████| 34/34 [00:01<00:00, 27.80it/s]\u001b[A\n",
      "                                                                                                                    \u001b[A"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'eval_loss': 0.06156477704644203, 'eval_accuracy': 0.9814814814814815, 'eval_runtime': 1.7696, 'eval_samples_per_second': 19.214, 'eval_steps_per_second': 2.826, 'epoch': 3.0}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████| 906/906 [09:08<00:00,  1.65it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'train_runtime': 548.775, 'train_samples_per_second': 52.83, 'train_steps_per_second': 1.651, 'train_loss': 0.215978719277624, 'epoch': 3.0}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "TrainOutput(global_step=906, training_loss=0.215978719277624, metrics={'train_runtime': 548.775, 'train_samples_per_second': 52.83, 'train_steps_per_second': 1.651, 'train_loss': 0.215978719277624, 'epoch': 3.0})"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# start training\n",
    "trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "dataset_infer = SentimentDataset(\"data/infer.tsv\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def predict(text, label=None):\n",
    "    \"\"\"Classify one sentence with the notebook-level ``model`` and print the result.\n",
    "\n",
    "    Args:\n",
    "        text: Sentence to classify; it is tokenized with the notebook-level ``tokenizer``.\n",
    "        label: Optional ground-truth class id (a key of ``label_map``); when\n",
    "            given it is printed next to the prediction.\n",
    "\n",
    "    Returns:\n",
    "        None. The formatted result is written to stdout.\n",
    "    \"\"\"\n",
    "    label_map = {0: \"消极\", 1: \"中性\", 2: \"积极\"}\n",
    "\n",
    "    # mindspore is imported in the first cell, so this definition does not rely\n",
    "    # on a bare `Tensor` name being imported somewhere later in the notebook.\n",
    "    # Wrapping the ids in a list adds the batch dimension.\n",
    "    input_ids = mindspore.Tensor([tokenizer(text).input_ids])\n",
    "    outputs = model(input_ids)\n",
    "    # the first element of the model output is used as the class scores; with a\n",
    "    # single sentence in the batch the flat arg-max is the predicted class id\n",
    "    predict_label = int(outputs[0].asnumpy().argmax())\n",
    "    info = f\"inputs: '{text}', predict: '{label_map[predict_label]}'\"\n",
    "    if label is not None:\n",
    "        info += f\" , label: '{label_map[label]}'\"\n",
    "    print(info)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "inputs: '我 要 客观', predict: '中性' , label: '中性'\n",
      "inputs: '靠 你 真是 说 废话 吗', predict: '消极' , label: '消极'\n",
      "inputs: '口嗅 会', predict: '中性' , label: '中性'\n",
      "inputs: '每次 是 表妹 带 窝 飞 因为 窝路痴', predict: '中性' , label: '中性'\n",
      "inputs: '别说 废话 我 问 你 个 问题', predict: '消极' , label: '消极'\n",
      "inputs: '4967 是 新加坡 那 家 银行', predict: '中性' , label: '中性'\n",
      "inputs: '是 我 喜欢 兔子', predict: '积极' , label: '积极'\n",
      "inputs: '你 写 过 黄山 奇石 吗', predict: '中性' , label: '中性'\n",
      "inputs: '一个一个 慢慢来', predict: '中性' , label: '中性'\n",
      "inputs: '我 玩 过 这个 一点 都 不 好玩', predict: '消极' , label: '消极'\n",
      "inputs: '网上 开发 女孩 的 QQ', predict: '中性' , label: '中性'\n",
      "inputs: '背 你 猜 对 了', predict: '中性' , label: '中性'\n",
      "inputs: '我 讨厌 你 ， 哼哼 哼 。 。', predict: '消极' , label: '消极'\n"
     ]
    }
   ],
   "source": [
    "from mindspore import Tensor\n",
    "\n",
    "for label, text in dataset_infer:\n",
    "    predict(text, label)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "inputs: '家人们咱就是说一整个无语住了 绝绝子叠buff', predict: '中性'\n"
     ]
    }
   ],
   "source": [
    "predict(\"家人们咱就是说一整个无语住了 绝绝子叠buff\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
